# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You may not
# use this file except in compliance with the License. A copy of the License
# is located at
#
#     http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is distributed on
# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.

from unittest import mock

import pytest

import sockeye.constants as C
from sockeye.train import fixed_param_names_from_strategy


# Number of encoder and decoder layers assumed by the parameter name fixtures
# below and reported via the mocked model config in the test.
NUM_LAYERS = 3

# Full, sorted list of parameter names for a mock 3-layer transformer model
# with one source factor and one target factor. Used as the input parameter
# set for `fixed_param_names_from_strategy` in the test below.
ALL_PARAMS = [
    'decoder.final_process.layer_norm.beta',
    'decoder.final_process.layer_norm.gamma',
    'decoder.layers.0.autoregr_layer.ff_in.weight',
    'decoder.layers.0.autoregr_layer.ff_out.weight',
    'decoder.layers.0.enc_attention.ff_kv.weight',
    'decoder.layers.0.enc_attention.ff_out.weight',
    'decoder.layers.0.enc_attention.ff_q.weight',
    'decoder.layers.0.ff.ff1.bias',
    'decoder.layers.0.ff.ff1.weight',
    'decoder.layers.0.ff.ff2.bias',
    'decoder.layers.0.ff.ff2.weight',
    'decoder.layers.0.pre_autoregr_layer.layer_norm.beta',
    'decoder.layers.0.pre_autoregr_layer.layer_norm.gamma',
    'decoder.layers.0.pre_enc_attention.layer_norm.beta',
    'decoder.layers.0.pre_enc_attention.layer_norm.gamma',
    'decoder.layers.0.pre_ff.layer_norm.beta',
    'decoder.layers.0.pre_ff.layer_norm.gamma',
    'decoder.layers.1.autoregr_layer.ff_in.weight',
    'decoder.layers.1.autoregr_layer.ff_out.weight',
    'decoder.layers.1.enc_attention.ff_kv.weight',
    'decoder.layers.1.enc_attention.ff_out.weight',
    'decoder.layers.1.enc_attention.ff_q.weight',
    'decoder.layers.1.ff.ff1.bias',
    'decoder.layers.1.ff.ff1.weight',
    'decoder.layers.1.ff.ff2.bias',
    'decoder.layers.1.ff.ff2.weight',
    'decoder.layers.1.pre_autoregr_layer.layer_norm.beta',
    'decoder.layers.1.pre_autoregr_layer.layer_norm.gamma',
    'decoder.layers.1.pre_enc_attention.layer_norm.beta',
    'decoder.layers.1.pre_enc_attention.layer_norm.gamma',
    'decoder.layers.1.pre_ff.layer_norm.beta',
    'decoder.layers.1.pre_ff.layer_norm.gamma',
    'decoder.layers.2.autoregr_layer.ff_in.weight',
    'decoder.layers.2.autoregr_layer.ff_out.weight',
    'decoder.layers.2.enc_attention.ff_kv.weight',
    'decoder.layers.2.enc_attention.ff_out.weight',
    'decoder.layers.2.enc_attention.ff_q.weight',
    'decoder.layers.2.ff.ff1.bias',
    'decoder.layers.2.ff.ff1.weight',
    'decoder.layers.2.ff.ff2.bias',
    'decoder.layers.2.ff.ff2.weight',
    'decoder.layers.2.pre_autoregr_layer.layer_norm.beta',
    'decoder.layers.2.pre_autoregr_layer.layer_norm.gamma',
    'decoder.layers.2.pre_enc_attention.layer_norm.beta',
    'decoder.layers.2.pre_enc_attention.layer_norm.gamma',
    'decoder.layers.2.pre_ff.layer_norm.beta',
    'decoder.layers.2.pre_ff.layer_norm.gamma',
    'embedding_source.factor1_weight',
    'embedding_source.weight',
    'embedding_target.factor1_weight',
    'embedding_target.weight',
    'encoder.final_process.layer_norm.beta',
    'encoder.final_process.layer_norm.gamma',
    'encoder.layers.0.ff.ff1.bias',
    'encoder.layers.0.ff.ff1.weight',
    'encoder.layers.0.ff.ff2.bias',
    'encoder.layers.0.ff.ff2.weight',
    'encoder.layers.0.pre_ff.layer_norm.beta',
    'encoder.layers.0.pre_ff.layer_norm.gamma',
    'encoder.layers.0.pre_self_attention.layer_norm.beta',
    'encoder.layers.0.pre_self_attention.layer_norm.gamma',
    'encoder.layers.0.self_attention.ff_in.weight',
    'encoder.layers.0.self_attention.ff_out.weight',
    'encoder.layers.1.ff.ff1.bias',
    'encoder.layers.1.ff.ff1.weight',
    'encoder.layers.1.ff.ff2.bias',
    'encoder.layers.1.ff.ff2.weight',
    'encoder.layers.1.pre_ff.layer_norm.beta',
    'encoder.layers.1.pre_ff.layer_norm.gamma',
    'encoder.layers.1.pre_self_attention.layer_norm.beta',
    'encoder.layers.1.pre_self_attention.layer_norm.gamma',
    'encoder.layers.1.self_attention.ff_in.weight',
    'encoder.layers.1.self_attention.ff_out.weight',
    'encoder.layers.2.ff.ff1.bias',
    'encoder.layers.2.ff.ff1.weight',
    'encoder.layers.2.ff.ff2.bias',
    'encoder.layers.2.ff.ff2.weight',
    'encoder.layers.2.pre_ff.layer_norm.beta',
    'encoder.layers.2.pre_ff.layer_norm.gamma',
    'encoder.layers.2.pre_self_attention.layer_norm.beta',
    'encoder.layers.2.pre_self_attention.layer_norm.gamma',
    'encoder.layers.2.self_attention.ff_in.weight',
    'encoder.layers.2.self_attention.ff_out.weight',
    'output_layer.bias',
    'output_layer.weight',
    'output_layer_factor1.bias',
    'output_layer_factor1.weight'
]


# Expected fixed parameters for the ALL_EXCEPT_DECODER strategy: every
# parameter that is not part of the decoder. Derived from ALL_PARAMS instead
# of duplicating the full name list (element-for-element identical to the
# previous hand-written fixture).
ALL_EXCEPT_DECODER_PARAMS = [name for name in ALL_PARAMS
                             if not name.startswith('decoder.')]

# Expected fixed parameters for the ALL_EXCEPT_OUTER_LAYERS strategy: every
# parameter that does not belong to the first or last encoder/decoder layer
# (layer 0 and layer NUM_LAYERS - 1). Derived from ALL_PARAMS instead of
# duplicating the name list (identical to the previous hand-written fixture).
ALL_EXCEPT_OUTER_LAYERS_PARAMS = [
    name for name in ALL_PARAMS
    if '.layers.0.' not in name and '.layers.{}.'.format(NUM_LAYERS - 1) not in name
]

# Expected fixed parameters for the ALL_EXCEPT_EMBEDDINGS strategy: every
# parameter that is not a source/target embedding (including factor
# embeddings). Derived from ALL_PARAMS instead of duplicating the full name
# list (element-for-element identical to the previous hand-written fixture).
ALL_EXCEPT_EMBED_PARAMS = [name for name in ALL_PARAMS
                           if not name.startswith('embedding_')]

# Expected fixed parameters for the ALL_EXCEPT_OUTPUT_PROJ strategy: every
# parameter that is not part of an output projection layer (covers both
# 'output_layer' and 'output_layer_factor1'). Derived from ALL_PARAMS instead
# of duplicating the full name list (element-for-element identical to the
# previous hand-written fixture).
ALL_EXCEPT_OUTPUT_PROJ_PARAMS = [name for name in ALL_PARAMS
                                 if not name.startswith('output_layer')]


@pytest.mark.parametrize("param_names, strategy, expected_fixed_param_names", [
    (ALL_PARAMS, C.FIXED_PARAM_STRATEGY_ALL_EXCEPT_DECODER, ALL_EXCEPT_DECODER_PARAMS),
    (ALL_PARAMS, C.FIXED_PARAM_STRATEGY_ALL_EXCEPT_OUTER_LAYERS, ALL_EXCEPT_OUTER_LAYERS_PARAMS),
    (ALL_PARAMS, C.FIXED_PARAM_STRATEGY_ALL_EXCEPT_EMBEDDINGS, ALL_EXCEPT_EMBED_PARAMS),
    (ALL_PARAMS, C.FIXED_PARAM_STRATEGY_ALL_EXCEPT_OUTPUT_PROJ, ALL_EXCEPT_OUTPUT_PROJ_PARAMS),
])
def test_fixed_param_strategy(param_names, strategy, expected_fixed_param_names):
    """
    Each fixed-parameter strategy should select exactly the expected subset of
    parameter names to freeze, given a model config with NUM_LAYERS encoder
    and decoder layers.
    """
    # Mock only the config attributes the strategy logic reads.
    config = mock.Mock()
    config.config_encoder.num_layers = NUM_LAYERS
    config.config_decoder.num_layers = NUM_LAYERS
    # Build the params dict from the parametrized names (previously this used
    # the module-level ALL_PARAMS directly, leaving `param_names` unused and
    # the parametrization ineffective). Values are unused by the function.
    params = {name: None for name in param_names}
    fixed_param_names = fixed_param_names_from_strategy(config, params, strategy)
    # Compare as sorted lists so the assertion is order-insensitive.
    assert sorted(fixed_param_names) == sorted(expected_fixed_param_names)
